# Defines the fit-estimate routines

#' Minimum estimated_partition class
#'
#' @param partition partition
#' @param cell_stats cell_stats
#' @param ... Additional arguments
#'
#' @return object of class estimated_partition
#' @export
estimated_partition <- function(partition, cell_stats, ...) {
  extra_params = list(...)
  
  if(!"m_mode" %in% names(extra_params)) extra_params$m_mode = DS.SINGLE
  if(!"M" %in% names(extra_params)) extra_params$M = 1
  
  return(structure(c(list(partition=partition, cell_stats=cell_stats), extra_params), 
                   class=c("estimated_partition")))
}

# Inherited params: bump_ratio, max_splits, max_cells, bucket_min_n, bucket_min_d_var, split_check_fn, breaks_per_dim, verbosity, partition_i
#           y, X, d, min_size
# Params expanded here: bump_samples, pr_cl, cv_folds
# Hidden params:
# - @param honest Whether to use the emse_hat or mse_hat. Use emse for outcome mean. For treatment effect, 
#               use if want predictive accuracy, but if only want to identify true dimensions of heterogeneity 
#               then don't use.
# Hidden returns: honest (don't advertise); m_mode, M (user knows this and I don't expose constants yet); has_d (user knows and not important)
#' Fit Grid Partition and estimate cell stats
#' 
#' Split the data, one one side train/fit the partition using \code{\link{fit_partition}} and then on the other estimate subgroup effects.
#' 
#' @inheritSection fit_partition Multiple estimates
#'
#' @param cv_folds Number of CV Folds or a vector of foldids. 
#'                If m_mode==DS.MULTI_SAMPLE, then a list with foldids per Dataset.
#'                Each must be over the training sample
#' @param tr_split Number between 0 and 1 or vector of indexes. If Multiple effect #3 and using vector then pass in list of vectors. 
#' @param ctrl_method Method for determining additional control variables. Empty ("") for nothing, "all",  "LassoCV", or "RF"
#' @param pr_cl Default NULL. Parallel cluster. Used for: \enumerate{
#'                \item CVing the optimal lambda,
#'                \item fitting full tree (at each split going across dimensions),
#'                \item fitting trees over the bumped samples
#'                \item for importance weights to estimate models over limited X domains
#'              }
#' @param alpha Significance threshold for confidence intervals. Default=0.05
#' @param bump_samples Number of bump bootstraps (default 0), or list of such length where each items is a bootstrap sample.
#'                     If m_mode==DS.MULTI_SAMPLE then each item is a sublist with such bootstrap samples over each dataset.
#'                     Each bootstrap sample must be over the train split of the data
#' @param importance_type Options:
#'                        single - (smart) redo full fitting removing each possible dimension
#'                        interaction - (smart) redo full fitting removing each pair of dimensions
#'                         "" - Nothing
#' @inheritParams fit_partition
#'
#' @return An object with class \code{"estimated_partition"}.
#'         \item{partition}{\code{\link{grid_partition}} obj defining cuts}
#'         \item{cell_stats}{Cell stats from \code{\link{est_cell_stats}$stats} on the est sample}
#'         \item{importance_weights}{Importance weights for each feature}
#'         \item{interaction_weights}{Interaction weights for each pair of features}
#'         \item{lambda}{lambda used}
#'         \item{is_obj_val_seq}{In-sample objective function values for sequence of partitions}
#'         \item{complexity_seq}{Complexity #s (# cells-1) for sequence of partitions}
#'         \item{partition_i}{Index of Partition selected in sequence}
#'         \item{split_seq}{Sequence of \code{partition_splits}s. Note that split i corresponds to partition i+1}
#'         \item{index_tr}{Index of training sample (we might have generated it). Order N}
#'         \item{cv_foldid}{CV foldids for the training sample (Size of N_tr)}
#'         \item{varnames}{varnames (or c("X1", "X2",...) if X doesn't have colnames)}
#'         \item{est_plan}{Fitted \code{\link{EstimatorPlan}} used.}
#'         \item{full_stat_df}{Full sample average stats from \code{\link{est_full_stats}}}
#'         
#' @export
fit_estimate_partition <- function(y, X, d=NULL, tr_split = 0.5, max_splits=Inf, max_cells=Inf, min_size=3, 
                                   cv_folds=5, potential_lambdas=NULL, partition_i=NA,
                                   verbosity=0, breaks_per_dim=NULL, 
                                   bucket_min_n=NA, bucket_min_d_var=FALSE, 
                                   ctrl_method="", pr_cl=NULL, alpha=0.05, bump_samples=0, bump_ratio=1,
                                   importance_type="", ...) {
  extra_params = list(...)
  honest=FALSE
  if(length(extra_params)>0) {
    if("honest" %in% names(extra_params)) honest = extra_params[['honest']]
    #don't check all extra names here as others are passed to another public function that takes hidden arguments
  }
  assert_that(!honest, msg="Honest eval metric not currently implimented.")
  honest=FALSE
  list[M, m_mode, N_tr, K] = get_sample_type(y, X, d, checks=TRUE)
  if(is_sep_sample(X) && length(tr_split)>1) {
    assert_that(is.list(tr_split) && length(tr_split)==M, msg="When separate sample & length(tr_split)>1, need is.list(tr_split) && length(tr_split)==M")
  }
  X = ensure_good_X(X)
  X = update_names_m(X)
  X_range = get_X_range(X)
  if(!is.list(breaks_per_dim)) {
    breaks_per_dim = get_quantile_breaks(X, X_range, g=breaks_per_dim)
  }
  
  dim_cat = which(get_dim_cat_m(X))
  
  index_tr = expand_tr_split_info(tr_split, N_tr, m_mode)
  
  list[y_tr, y_es, X_tr, X_es, d_tr, d_es, N_est] = split_sample_m(y, X, d, index_tr)
  
  
  #Setup est_plan
  if(is.character(ctrl_method)) {
    if(ctrl_method=="") {
      est_plan = gen_simple_est_plan(has_d=!is.null(d))
    } 
    else {
      assert_that(ctrl_method %in% c("all", "LassoCV", "RF"), !is.null(d), msg="String ctrl_method sepecified but not undestood or d is null.")
      est_plan = if(ctrl_method=="RF") grid_rf() else lm_est(lasso = ctrl_method=="LassoCV")
    }
  }
  else {
    assert_that(inherits(ctrl_method, "estimator_plan"), msg="Non-string ctrl_method specified that doesn't inherit from estimator_plan")
    est_plan = ctrl_method
  } 
  if(is_sep_estimators(m_mode)) {
    est_plan1 = est_plan
    est_plan = list()
    for(m in 1:M) est_plan[[m]] = est_plan1
  }
  
  X_range = get_X_range(X)
  t0 = Sys.time()
  main_params = list(X=X, y=y, d=d, X_tr=X_tr, y_tr=y_tr, d_tr=d_tr, y_es=y_es, X_es=X_es, d_es=d_es,
                     X_range=X_range, max_splits=max_splits, max_cells=max_cells, min_size=min_size, 
                     cv_folds=cv_folds, potential_lambdas=potential_lambdas, partition_i=partition_i, 
                     verbosity=verbosity, breaks_per_dim=breaks_per_dim,  bucket_min_n=bucket_min_n, 
                     bucket_min_d_var=bucket_min_d_var, honest=honest,
                     alpha=alpha, bump_samples=bump_samples, pr_cl=pr_cl, 
                     bump_ratio=bump_ratio, M=M, m_mode=m_mode,
                     dim_cat=dim_cat, N_est=N_est, est_plan=est_plan)
  main_ret = do.call(fit_estimate_partition_int, c(main_params, extra_params))
  list[partition, is_obj_val_seq, complexity_seq, partition_i, partition_seq, split_seq, lambda, cv_foldid, cell_stats, full_stat_df, est_plan] = main_ret
  
  importance_weights <- interaction_weights <- NULL
  if(importance_type=="interaction") {
    import_ret = do.call(get_feature_interactions, c(main_params, list(partition=partition), extra_params))
    importance_weights = import_ret$delta_k
    interaction_weights = import_ret$delta_k12
  }
  else if(importance_type %in% c("single", "fast")) {
    importance_weights = do.call(get_importance_weights, c(main_params, list(partition=partition, type=importance_type), 
                                                           extra_params))
  }
  
  tn = Sys.time()
  td = tn-t0
  if(verbosity>0) cat(paste("Entire Fit-Estimation Duration: ", format(as.numeric(td)), " ", attr(td, "units"), "\n"))
  
  return(structure(list(partition=partition, 
                        cell_stats=cell_stats, 
                        importance_weights=importance_weights, 
                        interaction_weights=interaction_weights, 
                        has_d=!is.null(d),
                        lambda=lambda, 
                        is_obj_val_seq=is_obj_val_seq,
                        complexity_seq=complexity_seq,
                        partition_i=partition_i,
                        split_seq=split_seq,
                        index_tr=index_tr,
                        cv_foldid=cv_foldid,
                        varnames=names(X),
                        honest=honest,
                        est_plan=est_plan,
                        full_stat_df=full_stat_df,
                        m_mode=m_mode,
                        M=M),
                   class = c("estimated_partition")))
}

#' is object estimated_partition
#' 
#' Tests whether the object is an \code{estimated_partition} object.
#'
#' @param x an R object
#'
#' @return True if x is an estimated_partition
#' @export
#' @describeIn fit_estimate_partition is estimated_partition
is_estimated_partition <- function(x) {
  inherits(x, "estimated_partition")
} 

#' @describeIn num_cells estimated_partition
#' @export
num_cells.estimated_partition <- function(obj) {
  return(num_cells(obj$partition))
}

# Inherited params: y, X, d
#' Change the complexity of a fit_estimate_partition
#' 
#' Change the complexity level of the partition and re-estimate cell statistics. 
#' If you have a minimal estimated_partition then you need to pass in the other params.
#' 
#' Note: doesn't update the importance weights
#' 
#' @inheritSection fit_partition Multiple estimates
#'
#' @param fit estimated_partition 
#' @param partition_i partition_i - 1 is the last include in split_seq included in new partition
#' @inheritParams fit_partition
#' @param index_tr Split between train and estimate samples (default is to get from \code{fit})
#' @param split_seq sequential list of splits (default is to get from \code{fit})
#'
#' @return updated estimated_partition
#' @export 
change_complexity <- function(fit, y, X, d=NULL, partition_i, index_tr = fit$index_tr,
                              split_seq = fit$split_seq, est_plan=fit$est_plan) {
  #TODO: Refactor checks from fit_estimation_partition and put them here
  X = ensure_good_X(X)
  X = update_names_m(X)
  list[y_tr, y_es, X_tr, X_es, d_tr, d_es, N_est] = split_sample_m(y, X, d, index_tr)
  
  fit$partition = partition_from_split_seq(split_seq, fit$partition$X_range, 
                                           varnames=fit$partition$varnames, max_include=partition_i-1)
  list[cell_factor, stats] = est_cell_stats(y_es, X_es, d_es, fit$partition, est_plan=est_plan)
  fit$cell_stats = stats
  
  return(fit)
}


#' Get descriptive data.frame
#'
#' Get information for each cell
#'
#' @inheritParams get_desc_df
#' @param  import_order Whether should use importance ordering
#'   (most important on the left) or input ordering (default) for features. Rows
#'   will be ordered so that the right-most will change most frequently.
#' 
#' 
#' @return data.frame with columns: partitioning columns, {N_est, param_ests,
#'   pval} per estimate
#' @export
get_desc_df.estimated_partition <- function(obj, cont_bounds_inf=TRUE, do_str=TRUE, drop_unsplit=TRUE, digits=NULL, unsplit_cat_star=TRUE, import_order=FALSE, ...) {
  M = obj$M
  stats = obj$cell_stats[c(F, rep(T,M), rep(T,M), rep(F,M),rep(F,M), rep(F,M), rep(F,M), rep(T,M), rep(F,M), rep(F,M))]
  part_df = get_desc_df(obj$partition, cont_bounds_inf=cont_bounds_inf, do_str=do_str, drop_unsplit=drop_unsplit, digits=digits, unsplit_cat_star=unsplit_cat_star)
  
  imp_weights = obj$importance_weights
  if(drop_unsplit) {
    imp_weights = imp_weights[obj$partition$nsplits_by_dim>0]
  }
  if(import_order) {
    part_df = part_df[, order(imp_weights, decreasing=TRUE)]
    part_df = part_df[do.call("order",part_df),] #re-sorts so that rightward changes most frequently
  } 
  
  return(cbind(part_df, stats))
}

# Inherited params: do_str, drop_unsplit, digits, import_order
#' Print estimated_partition
#' 
#' Print a summary of the estimated partition. Uses \code{\link{get_desc_df}}
#'
#' @param x estimated_partition object 
#' @inheritParams get_desc_df
#' @param import_order Whether should use importance ordering
#'   (most important on the left) or input ordering (default) for features.
#' @param ... Additional arguments. These will be passed to print.data.frame
#'
#' @return string (and displayed)
#' @export 
print.estimated_partition <- function(x, do_str=TRUE, drop_unsplit=TRUE, digits=NULL, import_order=FALSE, ...) {
  return(print(get_desc_df(x, do_str, drop_unsplit, digits, import_order=import_order), 
               digits=digits, ...))
}

#libs required and suggested. Use if sourcing directly. 
#lapply(lib_list, require, character.only = TRUE)
#CausalGrid_libs <- function(required=TRUE, suggested=TRUE, load_Rcpp=FALSE) {
#  lib_list = c()
#  if(required) lib_list = c(lib_list, "caret", "gsubfn", "assertthat")
#  if(load_Rcpp) lib_list = c(lib_list, "Rcpp")
#  if(suggested) lib_list = c(lib_list, "ggplot2", "glmnet", "gglasso", "parallel", "pbapply", "ranger")
#  #Build=Rcpp. Full dev=testthat, knitr, rmarkdown, renv, rprojroot
#  return(lib_list)
#}

#' Test for any sign effect
#' 
#' Accounting for multiple testing, is there a group with statistically significant specified sign effect 
#'
#' @param obj an \code{estimated_partition} object
#' @param check_negative If true, check for a negative. If false, check for positive. 
#' @param method one of c("fdr", "sim_mom_ineq"). \code{fdr} is conservative. 
#'               \code{sim_mom_ineq} Need samples sizes to sufficiently large so that the effects are normally distributed
#' @param alpha alpha
#' @param n_sim n_sim
#'
#' @return list(are_any= boolean of whether effect is negative)
#' @export
test_any_sign_effect <- function(obj, check_negative=T, method="fdr", alpha=0.05, n_sim=500) {
  #TODO: could also
  assert_that(method %in% c("fdr", "sim_mom_ineq"), msg="Unknown method (must be one of fdr or sim_mom_ineq.")
  if(method=="fdr") {
    assert_that(obj$has_d, alpha>0, alpha<1, msg="Testing significance requires d and alpha in (0,1)")
    dofs = obj$cell_stats[["N_est"]] - obj$est_plan$dof
    pval_right= pt(obj$cell_stats$tstats, df=dofs, lower.tail=FALSE) #right-tailed. Checking for just a positive effect (H_a is "greater")
    pval_left = pt(obj$cell_stats$tstats, df=dofs, lower.tail=TRUE) #left-tailed. Checking for just a negative effect (H_a is "less")
    pval1s = if(check_negative) pval_left else pval_right
    pval1s_fdr = p.adjust(pval1s, "BH")
    are_any = sum(pval1s_fdr<alpha) > 0
    return(list(are_any=are_any, pval1s=pval1s, pval1s_fdr=pval1s_fdr))
  }
  else {
    N_cell = nrow(obj$cell_stats)
    te_se = sqrt(obj$cell_stats[["var_ests"]])
    tstat_ext = if(check_negative) min(obj$cell_stats[["tstats"]]) else max(obj$cell_stats[["tstats"]])
    sim_tstat_exts = rep(NA, n_sim)
    for(s in 1:n_sim) {
      sim_te = rnorm(N_cell, mean=0, sd=te_se)
      sim_tstat_exts[s] =  if(check_negative) min(sim_te/te_se) else max(sim_te/te_se)
    }
    if(check_negative) {
      are_any = sum(sim_tstat_exts < quantile(sim_tstat_exts, alpha)) > 0
    }
    else {
      are_any = sum(sim_tstat_exts > quantile(sim_tstat_exts, 1-alpha)) > 0
    }
  }
  return(list(are_any=are_any))
}


# Inherited params: y, X, d
#' Estimate parameters across the cells
#' 
#' Estimate the parameters (including standard errors) across the cells in the sample.
#' 
#' @inheritSection fit_partition Multiple estimates
#'
#' @param partition (Optional, need this or cell_factor) partitioning returned from fit_estimate_partition
#' @param cell_factor (Optional, need this or partition) 
#' @param estimator_var (Optional) a function with signature list(param_est, var_est) = function(y, d) 
#'                      (where if no d then can pass in null). If NULL then will choose between built-in 
#'                      mean-estimator and scalar_te_estimator
#' @param est_plan Estimator plan
#' @param alpha Significance threshold
#' @inheritParams fit_partition
#' 
#' @return list
#' \item{cell_factor}{Factor with levels for each cell for X. Length N.}
#' \item{stats}{data.frame(cell_i, N_est, param_ests, var_ests, tstats, pval, ci_u, ci_l, p_fwer, p_fdr)}
#' @export
est_cell_stats <- function(y, X, d=NULL, partition=NULL, cell_factor=NULL, estimator_var=NULL, 
                           est_plan=NULL, alpha=0.05) {
  list[M, m_mode, N_tr, K] = get_sample_type(y, X, d, checks=TRUE)
  X = ensure_good_X(X)
  
  if(is.null(est_plan)) {
    if(is.null(estimator_var)) est_plan = gen_simple_est_plan(has_d=!is.null(d))
    else est_plan = simple_est(NULL, estimator_var)
  }
  if(is.null(cell_factor)) {
    cell_factor = predict(partition, X)
  }
  list[lvls, n_cells] = lcl_levels(cell_factor)
  param_ests = matrix(NA, nrow=n_cells, ncol=M)
  var_ests = matrix(NA, nrow=n_cells, ncol=M)
  cell_sizes = matrix(NA, nrow=n_cells, ncol=M)
  for(cell_i in 1:n_cells) {
    list[y_cell, d_cell, X_cell, N_l] <- get_cell(y, X, d, cell_factor, cell_i, lvls)
    cell_sizes[cell_i,] = N_l
    list[param_ests_c, var_ests_c] = Param_Est_m(est_plan, y_cell, d_cell, X_cell, sample="est", ret_var=TRUE, m_mode=m_mode)
    param_ests[cell_i,] = param_ests_c
    var_ests[cell_i,] = var_ests_c
  }
  dofs = t(t(cell_sizes) - get_dofs(est_plan, M, m_mode)) #subtract dofs from each row
  colnames(cell_sizes) = if(M==1) "N_est" else paste("N_est", 1:M, sep="")
  colnames(param_ests) = if(M==1) "param_ests" else paste("param_ests", 1:M, sep="")
  colnames(var_ests) = if(M==1) "var_ests" else paste("var_ests", 1:M, sep="")
  base_df = cbind(data.frame(cell_i=1:n_cells), cell_sizes, param_ests, var_ests)
  list[stat_df, pval] = exp_stats(base_df, param_ests, var_ests, dofs, alpha=alpha, M)
  p_fwer = matrix(p.adjust(pval, "hommel"), ncol=ncol(pval)) #slightly more powerful than "hochberg". Given indep these are better than "bonferroni" and "holm"
  p_fdr = matrix(p.adjust(pval, "BH"), ncol=ncol(pval)) #ours are independent so don't need "BY"
  colnames(p_fwer) = if(M==1) "p_fwer" else paste("p_fwer", 1:M, sep="")
  colnames(p_fdr) = if(M==1) "p_fdr" else paste("p_fdr", 1:M, sep="")
  stat_df = cbind(stat_df, p_fwer, p_fdr)
  return(list(cell_factor=cell_factor, stats=stat_df))
}


# Inherited params: y, X, d, est_plan, alpha
#' Estimate stats on the full samples
#' 
#' Estimates the parameters on the full and \code{est} samples
#' 
#' @inheritSection fit_partition Multiple estimates
#'
#' @param y_es y for \code{est} sample. Omit if providing \code{index_tr}.
#' @param X_es X for \code{est} sample. Omit if providing \code{index_tr}.
#' @param d_es d for \code{est} sample. Omit if providing \code{index_tr}.
#' @param index_tr Indexes of the \code{train} sample. Can be omitted if providing \code{y_es}, \code{X_es}, \code{d_es}.
#' @inheritParams est_cell_stats
#'
#' @return Stats df
#' @export
est_full_stats <- function(y, X, d, est_plan, y_es=NULL, X_es=NULL, d_es=NULL, index_tr=NULL, alpha=0.05) {
  list[M, m_mode, N_tr, K] = get_sample_type(y, X, d, checks=TRUE)
  X = ensure_good_X(X)
  
  if(is.null(y_es)) {
    list[y_tr, y_es, X_tr, X_es, d_tr, d_es, N_est] = split_sample_m(y, X, d, index_tr)
  }
  N_es = nrow_m(X_es, M)
  full_Ns = rbind(N_tr, N_es)
  colnames(full_Ns) = if(M==1) "N_est" else paste("N_est", 1:M, sep="")
  list[full_param_ests_all, full_var_ests_all] = Param_Est_m(est_plan, y, d, X, sample="est", ret_var=TRUE, m_mode=m_mode)
  list[full_param_ests_es, full_var_ests_es] = Param_Est_m(est_plan, y_es, d_es, X_es, sample="est", ret_var=TRUE, m_mode=m_mode)
  M = length(full_param_ests_all)
  full_param_ests = rbind(full_param_ests_all, full_param_ests_es)
  colnames(full_param_ests) = if(M==1) "param_ests" else paste("param_ests", 1:M, sep="")
  full_var_ests = rbind(full_var_ests_all, full_var_ests_es)
  colnames(full_var_ests) = if(M==1) "var_ests" else paste("var_ests", 1:M, sep="")
  base_df = cbind(data.frame(sample=c("all", "est")), full_Ns, full_param_ests, full_var_ests)
  dofs = t(t(full_Ns) - get_dofs(est_plan, M, m_mode)) #subtract dofs from each row
  list[full_stat_df, pval] = exp_stats(base_df, full_param_ests, full_var_ests, dofs, alpha=alpha, M)
  return(full_stat_df)
}


#' Generate predicted estimates per observations
#' 
#' Predicted unit-level treatment effect or outcome
#'
#' @param object estimated_partition object
#' @param new_X new X
#' @param new_d new d. Required for type="outcome"
#' @param type "effect" or "outcome" (currently not implemented)
#' @param ... Additional arguments. Unused.
#'
#' @return predicted treatment effect
#' @export
predict.estimated_partition <- function(object, new_X, new_d = NULL, type = "effect", ...) {
  #TODO: for mode 1 &2 maybe return a matrix rather than list
  
  new_X = ensure_good_X(new_X)
  new_X_range = get_X_range(new_X)
  
  cell_factor = predict(object$partition, new_X, new_X_range)
  M = object$M
  
  if(M==1) {
    N=nrow(new_X)
    cell_factor_df = data.frame(id=1:N, cell_i = as.integer(cell_factor))
    m_df = merge(cell_factor_df, object$cell_stats)
    m_df = m_df[order(m_df[["id"]]), ]
    return(m_df[["param_ests"]])
  }
  N = nrow_m(new_X, M)
  rets = list()
  for(m in 1:M) {
    cell_factor_df = data.frame(id=1:N[m], cell_i = as.integer(cell_factor[[m]]))
    m_df = merge(cell_factor_df, object$cell_stats)
    m_df = m_df[order(m_df[["id"]]), ]
    rets[[m]] = m_df[["param_ests"]]
  }
  return(rets)
}


# Inherited params: y, X, d
#' Evaluate the MSE_hat objective function
#' 
#' Evaluate the MSE_hat objective function over the cells for the sample
#'
#' @param y_tr y_tr
#' @param X_tr X_tr
#' @param d_tr d_tr
#' @param y_te y_te
#' @param X_te X_te
#' @param d_te d_te
#' @param N_est Size of estimation sample. Unused for this objective function.
#' @param partition Grid partition. Pass in this or \code{cell_factor}.
#' @param cell_factor_tr Factor for cells for each observation. Pass in this or \code{partition}. 
#' @param est_plan Estimation plan. If this and  \code{estimator} are null, then one is created as  \code{gen_simple_est_plan(has_d=!is.null(d))}
#' @param estimator If not passing in \code{est_plan}, can pass in this estimation routine and one is created using \code{\link{simple_est}}.
#' @param debug T/F whether we are in debug mode (and print out more info)
#' @param warn_on_error  T/F for whether to display a warning when estimation fails (NA values returned)
#' @param sample Passed to \code{\link{est_params}}

#' @return \code{c(val, N_cell_empty, N_cell_error)}
#' @export
#' @keywords internal
eval_mse_hat <-function(y_tr, X_tr, d_tr, y_te=NULL, X_te=NULL, d_te=NULL, N_est, partition=NULL, cell_factor_tr=NULL, cell_factor_te=NULL, est_plan=NULL, estimator=NULL, debug=FALSE, 
                        warn_on_error=FALSE, sample="trtr", ...) {
  incl_te = !is.null(y_te)
  if(is.null(est_plan)) {
    if(is.null(estimator)) est_plan = gen_simple_est_plan(has_d=!is.null(d_tr))
    else est_plan = simple_est(estimator, estimator)
  }
  
  if(is.null(cell_factor_tr)) {
    cell_factor_tr = predict(partition, X_tr)
    if(incl_te) cell_factor_te = predict(partition, X_te)
  }
  list[M, m_mode, N_tr, K] = get_sample_type(y_tr, X_tr, d_tr, checks=FALSE)
  if(incl_te) list[M, m_mode, N_te, K] = get_sample_type(y_te, X_te, d_te, checks=FALSE)
  list[lvls, n_cells] = lcl_levels(cell_factor_tr)
  cell_contribs = rep(0, n_cells)
  N_eff = 0
  N_cell_empty = 0
  N_cell_error = 0
  for(cell_i in 1:n_cells) {
    list[y_tr_cell, d_tr_cell, X_tr_cell, N_tr_l] <- get_cell(y_tr, X_tr, d_tr, cell_factor_tr, cell_i, lvls)
    if(incl_te) list[y_te_cell, d_te_cell, X_te_cell, N_te_l] <- get_cell(y_te, X_te, d_te, cell_factor_te, cell_i, lvls)
    if(any(N_tr_l==0) || (incl_te && any(N_te_l==0))) {
      N_cell_empty = N_cell_empty+1
      next
    }
    list[param_est_tr] = Param_Est_m(est_plan, y_tr_cell, d_tr_cell, X_tr_cell, sample=sample, ret_var=FALSE, m_mode)
    if(incl_te) list[param_est_te] = Param_Est_m(est_plan, y_te_cell, d_te_cell, X_te_cell, sample=sample, ret_var=FALSE, m_mode)
    if(!all(is.finite(param_est_tr)) || (incl_te && !all(is.finite(param_est_te)))) {
      N_cell_error = N_cell_error+1
      msg = paste("Failed estimation: (N_tr_l=", N_tr_l, ")") #if printing var(d_cell), remember it could be a list
      if(incl_te) msg = paste(msg, "(N_te_l", N_te_l, ")")
      if(warn_on_error) warning(paste(msg,"\n"))
      next
    }
    if(incl_te) t1_mse = sum(N_te_l*(2*param_est_tr*param_est_te - param_est_tr^2))
    else t1_mse = sum(N_tr_l*param_est_tr^2)
    N_eff = N_eff + ifelse(incl_te, sum(N_te_l), sum(N_tr_l))
    cell_contribs[cell_i] = t1_mse
    if(debug) { 
      if(incl_te) print(paste("cell_i=", cell_i, "; N_tr_l=", N_tr_l, "; param_est_tr=", param_est_tr, "; N_te_l=", N_te_l, "; param_est_te=", param_est_te))
      else print(paste("cell_i=", cell_i, "; N_tr_l=", N_tr_l, "; param_est_tr=", param_est_tr))
    }
  }
  val = -1/N_eff*sum(cell_contribs) # Use N_eff to remove from average given errors
  if(debug) print(paste("cell sums", val))
  return(c(val, N_cell_empty, N_cell_error))
}

#Version of the above if we've already estimated the table.
# Doesn't do any error checking
eval_mse_hat_tbl <- function(train_tbl, te_tbl=NULL, incl_base_te=TRUE) {
  if(is.null(te_tbl)) {
    val = sum(train_tbl['N_est']*train_tbl['param_ests']^2)
    val = -1/(sum(train_tbl['N_est']))*val
  }
  else {
    if(incl_base_te) {
      val = sum(te_tbl['N_est']*(te_tbl['param_ests'] - train_tbl['param_ests'])^2)
      }
    else {
      val = sum(te_tbl['N_est']*(-2*train_tbl['param_ests']*te_tbl['param_ests'] + train_tbl['param_ests']^2))
    }
    val = 1/(sum(te_tbl['N_est']))*val
  }
  return(val)
}

#TODO: Adapt to allow for separate sample evaluation (used for CV). This is done for eval_mse_hat
# If empty and err cells will be removed from the calculation, but counts of these returned
# estimator_var: takes sample for cell and returns coefficient estimate and estimated variance of estimate
# eval_emse_hat<-function(y, X , d, N_est, partition=NULL, cell_factor=NULL, estimator_var=NULL, debug=FALSE, 
#                     warn_on_error=FALSE, alpha=NULL, est_plan=NULL, sample="trtr") {
#   if(is.null(est_plan)) {
#     if(is.null(estimator_var)) est_plan = gen_simple_est_plan(has_d=!is.null(d))
#     else est_plan = simple_est(estimator_var, estimator_var)
#   }
#   if(is.null(cell_factor)) {
#     cell_factor = predict(partition, X)
#   }
#   if(!is.null(alpha)) {
#     stopifnot(alpha>=0 & alpha<=1)
#   }
#   list[M, m_mode, N_tr, K] = get_sample_type(y, X, d, checks=FALSE)
#   list[lvls, n_cells] = lcl_levels(cell_factor)
#   cell_contribs1 = rep(0, n_cells)
#   cell_contribs2 = rep(0, n_cells)
#   N_eff = 0
#   N_cell_empty = 0
#   N_cell_err = 0
#   for(cell_i in 1:n_cells) {
#     list[y_cell, d_cell, X_cell, N_l] <- get_cell(y, X, d, cell_factor, cell_i, lvls)
#     if(any(N_l==0)) {
#       N_cell_empty = N_cell_empty+1
#       next
#     }
#     
#     list[param_est, var_est] = Param_Est_m(est_plan, y_cell, d_cell, X_cell, sample=sample, ret_var=TRUE, m_mode) 
#     if(!all(is.finite(param_est)) || !all(is.finite(var_est))) {
#       N_cell_err = N_cell_err+1
#       msg = paste("Failed estimation: (N_l=", N_l, ", param_est=", param_est, ", var_est=", var_est,
#                   ifelse(!is.null(d), paste(", var_d=", var(d_cell)) , ""),
#                   ")\n")
#       if(warn_on_error) warning(msg)
#       next
#     }
#     t1_mse = sum(N_l*param_est^2)
#     t2_var = sum(N_l*var_est)
#     N_eff = N_eff + sum(N_l)
#     cell_contribs1[cell_i] = t1_mse
#     cell_contribs2[cell_i] = t2_var
#     if(debug) print(paste("N=", N_tr, "; cell_i=", cell_i, "; N_l=", N_l, "; param_est=", param_est, 
#                           "; var_est=", var_est))
#   }
#   t1_mse = -1/N_eff*sum(cell_contribs1)
#   t2_var = (1/N_eff + 1/N_est)*sum(cell_contribs2) 
#   val = if(is.null(alpha)) t1_mse + t2_var else alpha*t1_mse + (1-alpha)*t2_var
#   if(debug) print(paste("cell sums", val))
#   if(!is.finite(val)) stop("Non-finite val")
#   return(c(val, N_cell_empty, N_cell_err))
# }

get_cell <- function(y, X, d, cell_factor, cell_i, lvls) {
  list[M, m_mode, N_tr, K] = get_sample_type(y, X, d)
  if(is_sep_sample(X)) {
    y_cell = d_cell = X_cell = list()
    N_l = rep(0, M)
    for(m in 1:M) {
      cell_ind = cell_factor[[m]]==lvls[[m]][cell_i]
      y_cell[[m]] = y[[m]][cell_ind]
      d_cell[[m]] = d[[m]][cell_ind]
      X_cell[[m]] = X[[m]][cell_ind, , drop=FALSE]
      N_l[m] = sum(cell_ind)
    }
  }
  else {
    cell_ind = cell_factor==lvls[cell_i]
    N_l = if(M==1) sum(cell_ind) else rep(sum(cell_ind), M)
    y_cell = if(is_vec(y)) y[cell_ind] else y[cell_ind, , drop=FALSE]
    d_cell = if(is_vec(d)) d[cell_ind] else d[cell_ind, , drop=FALSE]
    X_cell = X[cell_ind, , drop=FALSE]
  }
  return(list(y_cell, d_cell, X_cell, N_l))
}

#Split the sample
expand_tr_split_info<-function(tr_split, N_tr, m_mode) {
  if(length(tr_split)==1)
    return(gen_split_m(N_tr, tr_split, m_mode==1))
  return(tr_split)
}

lcl_levels <- function(cell_factor) {
  if(!is.list(cell_factor)) {
    lvls = levels(cell_factor)
    return(list(lvls, length(lvls)))
  }
  lvls = lapply(cell_factor, levels)
  return(list(lvls, length(lvls[[1]])))
}

get_dofs <- function(est_plan, M, m_mode) {
  if(m_mode==0) 
    return(est_plan$dof)
  
  if(is_sep_estimators(m_mode))
    return(sapply(est_plan, function(plan) plan$dof))
  
  #Only 1 estimator but multiple d, so make sure right length
  dof = est_plan$dof
  if(length(dof)==1) dof=rep(dof, M)
  return(dof)
}

exp_stats <- function(stat_df, param_ests, var_ests, dofs, alpha=0.05, M) {
  tstats = param_ests/sqrt(var_ests)
  colnames(tstats) = if(M==1) "tstats" else paste("tstats", 1:M, sep="")
  t.half.alpha = qt(1-alpha/2, df=dofs)*sqrt(var_ests)
  ci_u = param_ests + t.half.alpha
  colnames(ci_u) = if(M==1) "ci_u" else paste("ci_u", 1:M, sep="")
  ci_l = param_ests - t.half.alpha
  colnames(ci_l) = if(M==1) "ci_l" else paste("ci_l", 1:M, sep="")
  pval = 2*pt(abs(tstats), df=dofs, lower.tail=FALSE)
  colnames(pval) = if(M==1) "pval" else paste("pval", 1:M, sep="")
  stat_df = cbind(stat_df, tstats, ci_u, ci_l, pval)
  #pval_right= pt(tstats, df=dofs, lower.tail=FALSE) #right-tailed. Checking for just a positive effect (H_a is "greater")
  #pval_left = pt(tstats, df=dofs, lower.tail=TRUE) #left-tailed. Checking for just a negative effect (H_a is "less")
  return(list(stat_df, pval))
}

get_importance_weights_full_k <- function(k_i, to_compute, X_d, y, d, X_tr, y_tr, d_tr, y_es, X_es, d_es, X_range, breaks_per_dim, verbosity, ...) {
  if(verbosity>0) cat(paste("Feature weight > ", k_i, "of", length(to_compute),"\n"))
  k = to_compute[k_i]
  X_k = drop_col_k_m(X_d, k)
  X_tr_k = drop_col_k_m(X_tr, k)
  X_es_k = drop_col_k_m(X_es, k)
  main_ret = fit_estimate_partition_int(X_k, y, d, X_tr_k, y_tr, d_tr, y_es, X_es_k, d_es, X_range[-k], breaks_per_dim=breaks_per_dim[-k], verbosity=verbosity, nsplits_k_warn_limit=NA, ...)
  nk_val = eval_mse_hat(y_es, X_es_k, d_es, partition=main_ret$partition, est_plan=main_ret$est_plan, sample="est")[1] #use oos version instead of main_ret$is_obj_val_seq[partition_i]
  return(list(nk_val, main_ret$partition$nsplits_by_dim))
}

# Just use mse_hat as we're working not on the Tr sample, but the est sample
# The ... params are passed to get_importance_weights_full_k -> fit_estimate_partition_int
# There's an undocumented "fast" version. Not very great as assings 0 to any feature not split on
get_importance_weights <- function(X, y, d, X_tr, y_tr, d_tr, y_es, X_es, d_es, X_range, breaks_per_dim, partition, est_plan, type, verbosity, pr_cl, ...) {
  if(verbosity>0) cat("Feature weights: Started.\n")
  K = length(X_range)
  if(sum(partition$nsplits_by_dim)==0) return(rep(0, K))
  full_val = eval_mse_hat(y_es, X_es, d_es, partition = partition, est_plan=est_plan, sample="est")[1]
  
  if(K==1) {
    null_val = eval_mse_hat(y_es, X_es, d_es, partition = grid_partition(partition$X_range, partition$varnames), est_plan=est_plan, sample="est")[1]
    if(verbosity>0) cat("Feature weights: Finished.\n")
    return(null_val - full_val)
  }
  
  if("fast"==type) {
    new_vals = rep(0, K)
    factors_by_dim = get_factors_from_partition(partition, X_es)
    for(k in 1:K) {
      if(partition$nsplits_by_dim[k]>0) {
        cell_factor_nk = gen_holdout_interaction_m(factors_by_dim, k, is_sep_sample(X_tr))
        new_vals[k] = eval_mse_hat(y_es, X_es, d_es, cell_factor_tr = cell_factor_nk, est_plan=est_plan, sample="est")[1]
      }
    }
    if(verbosity>0) cat("Feature weights: Finished.\n")
    return(new_vals - full_val)
  }
  
  #if("full"==type)
  new_vals = rep(full_val, K)
  to_compute = which(partition$nsplits_by_dim>0)

  params = c(list(to_compute, X_d=X, y=y, d=d, X_tr=X_tr, y_tr=y_tr, d_tr=d_tr, y_es=y_es, X_es=X_es, d_es=d_es, X_range=X_range, breaks_per_dim=breaks_per_dim, est_plan=est_plan, verbosity=verbosity-1), 
             list(...))

  rets = my_apply(1:length(to_compute), get_importance_weights_full_k, verbosity==1 || !is.null(pr_cl), pr_cl, params)
  for(k_i in 1:length(to_compute)) {
    k = to_compute[k_i]
    new_vals[k] = rets[[k_i]][[1]]
  }
  if(verbosity>0) cat("Feature weights: Finished.\n")
  return(new_vals - full_val)
}

get_feature_interactions_k12 <- function(ks_i, to_compute, X_d, y, d, X_tr, y_tr, d_tr, y_es, X_es, d_es, X_range, breaks_per_dim, verbosity, ...) {
  if(verbosity>0) cat(paste("Feature interaction weight > ", ks_i, "of", length(to_compute),"\n"))
  ks = to_compute[[ks_i]]
  X_k = drop_col_k_m(X_d, ks)
  X_tr_k = drop_col_k_m(X_tr, ks)
  X_es_k = drop_col_k_m(X_es, ks)
  main_ret = fit_estimate_partition_int(X_k, y, d, X_tr_k, y_tr, d_tr, y_es, X_es_k, d_es, X_range[-ks], breaks_per_dim=breaks_per_dim[-ks], verbosity=verbosity, nsplits_k_warn_limit=NA, ...)
  nk_val = eval_mse_hat(y_es, X_es_k, d_es, partition=main_ret$partition, est_plan=main_ret$est_plan, sample="est")[1] #use oos version instead of main_ret$is_obj_val_seq[partition_i]
  return(nk_val)
  
}

get_feature_interactions <- function(X, y, d, X_tr, y_tr, d_tr, y_es, X_es, d_es, X_range, breaks_per_dim, partition, est_plan, verbosity, pr_cl, ...) {
  
  if(verbosity>0) cat("Feature weights: Started.\n")
  K = length(X_range)
  dnames = list(colnames(X), colnames(X))
  delta_k12 = matrix(as.integer(diag(rep(NA, K))), ncol=K, dimnames=dnames) #dummy for K<3 cases
  if(sum(partition$nsplits_by_dim)==0){ 
    if(verbosity>0) cat("Feature weights: Finished.\nFeature interaction weights: Started.\nFeature interaction interactions: Finished.\n")
    return(list(delta_k=rep(0, K), delta_k12=delta_k12))
  }
  full_val = eval_mse_hat(y_es, X_es, d_es, partition = partition, est_plan=est_plan, sample="est")[1]
  
  if(K==1) {
    null_val = eval_mse_hat(y_es, X_es, d_es, partition = grid_partition(partition$X_range, partition$varnames), est_plan=est_plan, sample="est")[1]
    if(verbosity>0) cat("Feature weights: Finished.\nFeature interaction weights: Started.\nFeature interaction interactions: Finished.\n")
    return(list(delta_k=null_val - full_val, delta_k12=delta_k12))
  }
  
  #compute the single-removed values (and keep around the nsplits from each new partition)
  new_val_k = rep(full_val, K)
  to_compute_k = which(partition$nsplits_by_dim>0)
  params = c(list(to_compute=to_compute_k, X_d=X, y=y, d=d, X_tr=X_tr, y_tr=y_tr, d_tr=d_tr, y_es=y_es, X_es=X_es, d_es=d_es, X_range=X_range, breaks_per_dim=breaks_per_dim, est_plan=est_plan, verbosity=verbosity-1), 
             list(...))
  rets_k = my_apply(1:length(to_compute_k), get_importance_weights_full_k, verbosity==1 || !is.null(pr_cl), pr_cl, params)
  for(k_i in 1:length(to_compute_k)) {
    k = to_compute_k[k_i]
    new_val_k[k] = rets_k[[k_i]][[1]]
  }
  delta_k = new_val_k - full_val
  if(K==2) {
    null_val = eval_mse_hat(y_es, X_es, d_es, partition = grid_partition(partition$X_range, partition$varnames), est_plan=est_plan, sample="est")[1]
    delta_k12 = matrix(null_val - full_val, ncol=2) + diag(rep(NA, K))
    colnames(delta_k12) = colnames(X)
    rownames(delta_k12) = colnames(X)
    if(verbosity>0) cat("Feature weights: Finished.\nFeature interaction weights: Started.\nFeature interaction interactions: Finished.\n")
    return(list(delta_k=delta_k, delta_k12=delta_k12))
  } 
  if(verbosity>0) cat("Feature weights: Finished.\n")
  
  
  #Compute the pair-removed values
  if(verbosity>0) cat("Feature interaction weights: Started.\n")
  new_val_k12 = matrix(full_val, ncol=K, nrow=K)
  to_compute = list()
  for(k1 in 1:(K-1)) {
    if(partition$nsplits_by_dim[k1]==0) {
      new_val_k12[k1,] = new_val_k
      new_val_k12[,k1] = new_val_k
    }
    else {
      k1_i = which(to_compute_k==k1)
      nsplits_by_dim_k1= rets_k[[k1_i]][[2]]
      for(k2 in (k1+1):K) {
        if(nsplits_by_dim_k1[k2-1]==0) { #nsplits_by_dim_k1 is missing k1 so drop k2 back one
          new_val_k12[k1,k2] = new_val_k[k1]
          new_val_k12[k2,k1] = new_val_k[k1]
        }
        else {
          to_compute = c(list(c(k1, k2)), to_compute)
        }
      }
    }
  }
  params = c(list(to_compute=to_compute, X_d=X, y=y, d=d, X_tr=X_tr, y_tr=y_tr, d_tr=d_tr, y_es=y_es, X_es=X_es, d_es=d_es, X_range=X_range, breaks_per_dim=breaks_per_dim, est_plan=est_plan, verbosity=verbosity-1), 
             list(...))
  rets_k12 = my_apply(1:length(to_compute), get_feature_interactions_k12, verbosity==1 || !is.null(pr_cl), pr_cl, params)
  for(ks_i in 1:length(to_compute)) {
    k1 = to_compute[[ks_i]][1]
    k2 = to_compute[[ks_i]][2]
    new_val = rets_k12[[ks_i]]
    new_val_k12[k1, k2] = new_val
    new_val_k12[k2, k1] = new_val
  }
  delta_k12 = t(t((new_val_k12 - full_val) - delta_k) - delta_k) + diag(rep(NA, K))
  colnames(delta_k12) = colnames(X)
  rownames(delta_k12) = colnames(X)
  if(verbosity>0) cat("Feature interaction interactions: Finished.\n")
  return(list(delta_k=delta_k, delta_k12=delta_k12))
}


fit_and_residualize <- function(est_plan, X_tr, y_tr, d_tr, cv_folds, y_es, X_es, d_es, verbosity, dim_cat) {
  est_plan = fit_on_train(est_plan, X_tr, y_tr, d_tr, cv_folds, verbosity=verbosity, dim_cat=dim_cat)
  list[y_tr, d_tr] = residualize(est_plan, y_tr, X_tr, d_tr, sample="tr")
  list[y_es, d_es] = residualize(est_plan, y_es, X_es, d_es, sample="est")
  return(list(est_plan, y_tr, d_tr, y_es, d_es))
}

# ... params are passed to fit_partition()
fit_estimate_partition_int <- function(X, y, d, X_tr, y_tr, d_tr, y_es, X_es, d_es, dim_cat, X_range, est_plan, honest, cv_folds, verbosity, M, m_mode, 
                                       alpha, ...) {
  K = length(X_range)
  obj_fn = if(honest) eval_emse_hat else eval_mse_hat
  
  list[nfolds, folds_ret, foldids] = expand_fold_info(y_tr, cv_folds, m_mode)
  
  if(!is.null(d))
    list[est_plan, y_tr, d_tr, y_es, d_es] = fit_and_residualize_m(est_plan, X_tr, y_tr, d_tr, foldids, y_es, X_es, d_es, m_mode, M, verbosity, dim_cat)
  
  if(verbosity>0) cat("Training partition on training set\n")
  fit_ret = fit_partition(y=y_tr, X=X_tr, d=d_tr, X_aux=X_es, d_aux=d_es, cv_folds=foldids, verbosity=verbosity, 
                          X_range=X_range, obj_fn=obj_fn, est_plan=est_plan, ...)
  list[partition, is_obj_val_seq, complexity_seq, partition_i, partition_seq, split_seq, lambda, cv_foldid] = fit_ret
  
  if(verbosity>0) cat("Estimating cell statistics on estimation set\n")
  list[cell_factor, cell_stats] = est_cell_stats(y_es, X_es, d_es, partition, est_plan=est_plan, alpha=alpha)
  
  full_stat_df = est_full_stats(y, X, d, est_plan, y_es=y_es, X_es=X_es, d_es=d_es)
  
  return(list(partition=partition, is_obj_val_seq=is_obj_val_seq, complexity_seq=complexity_seq, partition_i=partition_i, partition_seq=partition_seq, split_seq=split_seq, lambda=lambda, cv_foldid=cv_foldid, cell_stats=cell_stats, full_stat_df=full_stat_df, est_plan=est_plan))
}
